############################################
### Irradiation Brood Viability Analysis ###
############################################

### Packages ###
library(tidyverse)
library(DescTools)
library(reshape2)
library(readxl)
library(rstan)

# Load the per-worm brood counts from the source-data workbook
data <- read_excel(
  path = 'Figure 4 source data 3.xlsx', # Input file path to source data here
  sheet = 'Per-worm Counts'
)

# Derive the integer-coded columns that the Stan model indexes by:
#   Geno_num   - position of the genotype in the factor level order below
#   Window_num - 1 = 'Non-IH Window', 2 = 'IH Window'
#   Treat_num  - 1 = '0 Rad', 2 = '2500 Rad', 3 = '5000 Rad'
#   Bsize      - total brood size (Live + Dead)
# Any label that matches none of the listed values is coded as NA.
stan_data <- data %>%
  mutate(
    Genotype = factor(
      Genotype,
      levels = c(
        'N2',
        'brc-1(xoe4)',
        'smc-5(ok2421)',
        'smc-5(ok2421);brc-1(xoe4)',
        'smc-5(ok2421);brc-1(xoe4) MATED',
        'smc-5(ok2421);brc-1(xoe4);lig-4(ok716)',
        'smc-5(ok2421);brc-1(xoe4);polq-1(tm2572)',
        'smc-5(ok2421);brc-1(xoe4);polq-1(tm2572) MATED',
        'polq-1(tm2572)',
        'polq-1(tm2572) MATED'
      )
    ),
    Geno_num = as.numeric(Genotype),
    Window_num = case_when(
      Window == 'IH Window' ~ 2,
      Window == 'Non-IH Window' ~ 1,
      TRUE ~ NA_real_
    ),
    Treat_num = case_when(
      Treatment == '0 Rad' ~ 1,
      Treatment == '2500 Rad' ~ 2,
      Treatment == '5000 Rad' ~ 3,
      TRUE ~ NA_real_
    ),
    Bsize = Live + Dead
  )

###################
### Rstan Model ###
###################

# Stan program, held as a character string and handed to stan_model() below.
#   data:        per-brood counts (`live`, `bsize`) plus integer indices giving
#                each brood's genotype, dose and timepoint.
#   parameters:  for every genotype x dose x timepoint cell, a binomial
#                probability `p` and two Beta shape parameters `alpha`, `beta`.
#   transformed parameters:
#                phi = alpha / (alpha + beta) and lambda = alpha + beta.
#   model:       alpha and beta are declared with lower bound 0 and given
#                normal(0, 50) priors; p ~ beta(lambda * phi, lambda * (1 - phi)),
#                i.e. beta(alpha, beta); each brood's live count is
#                binomial(bsize, p) using the p of the cell the brood belongs to.
#   generated quantities:
#                gamma = p divided by the p at dose index 1 for the same
#                genotype and timepoint.
# NOTE(review): the inline Stan comment on `live` says "timept at which progeny
# were scored", but `live` is the binomial outcome in the model block and is fed
# stan_data$Live in the sampling call - the comment looks like a copy-paste
# slip; confirm and correct inside the Stan text.
modelcode <- "data {
  int Genos; //total number of genotypes scored
  int Timepts; //number of Timepts scored
  int Doses; //Irradiation dose treatments;
  int Herms; //Number of hermaphrodite broods scored
  int geno[Herms]; // genotype category of each herm
  int dose[Herms]; //dose of IR administered to each herm
  int timept[Herms]; //timepoint progeny were laid in
  int live[Herms]; //timept at which progeny were scored
  int bsize[Herms]; //number of total living progeny + dead eggs in brood of given hermaphrodite
}
parameters {
  real <lower=0,upper=1> p[Genos,Doses,Timepts]; //p parameter of binomial distribution
  real <lower=0> alpha[Genos,Doses,Timepts]; //alpha shape parameter of beta distribution
  real <lower=0> beta[Genos,Doses,Timepts]; //beta shape parameter of beta distribution
}
transformed parameters {
  real <lower=0,upper=1> phi[Genos,Doses,Timepts]; //phi (alpha / (alpha + beta))
  real <lower=0> lambda[Genos,Doses,Timepts]; //lambda (Beta dist alpha + beta shape params)
  
  for (g in 1:Genos) {
    for (d in 1:Doses) {
      for (s in 1:Timepts) {
        phi[g,d,s] = alpha[g,d,s] / (alpha[g,d,s] + beta[g,d,s]);
        lambda[g,d,s] = alpha[g,d,s] + beta[g,d,s];
      }
    }
  }
}
model {
  for (g in 1:Genos) {
    for (d in 1:Doses) {
      alpha[g,d,] ~ normal(0,50);
      beta[g,d,] ~ normal(0,50);
      for (t in 1:Timepts) {
        p[g,d,t] ~ beta(lambda[g,d,t] * phi[g,d,t], lambda[g,d,t] * (1 - phi[g,d,t]));
      }
    }
  }
  for (h in 1:Herms) {
    live[h] ~ binomial(bsize[h],p[geno[h],dose[h],timept[h]]);
    }
}
generated quantities{
  real gamma[Genos, Doses,Timepts];
  for (d in 1:Doses) {
    for(t in 1:Timepts) {
      for(g in 1:Genos) {
        gamma[g,d,t] = p[g,d,t]/p[g,1,t]; //p parameter of binomial normalized to p of unirradiated treatment - this is the 'sensitivity' to IR of a given genotype
      }
    }
  }
}"


# Build the model object from the program text above; it is reused by
# sampling() in the fitting section.
model <- stan_model(model_code=modelcode)

##############################
### Fit Stan model to data ###
##############################

# Seed R's RNG, then draw from the posterior. The data list supplies the array
# dimensions declared in the Stan data block (Genos, Timepts, Doses, Herms)
# followed by one entry per brood for each index and count vector.
set.seed(666)
stanrun <- sampling(
  model,
  data = list(
    Genos = length(unique(stan_data$Geno_num)),
    Timepts = 2,
    Doses = 3,
    Herms = nrow(stan_data),
    geno = stan_data$Geno_num,
    dose = stan_data$Treat_num,
    timept = as.integer(stan_data$Window_num),
    live = stan_data$Live,
    bsize = stan_data$Bsize
  ),
  iter = 1e4,
  control = list(adapt_delta = 0.8)
)

### Tabulate the posterior summary, one row per parameter element, for plotting
# The summary row names (e.g. 'gamma[3,2,1]') are moved into a column and split
# into the parameter name plus its three indices. separate() is left on its
# default separator, so rows whose name does not yield exactly four pieces are
# padded or truncated by separate() itself.
stanrun_summ <- summary(stanrun)$summary %>%
  as_tibble(rownames = 'Parameter') %>%
  separate(
    col = Parameter,
    into = c('Parameter', 'GenotypeID', 'Treatment', 'Timepoint')
  ) %>%
  # Friendlier aliases for the 95% interval bounds
  mutate(
    low95CI = `2.5%`,
    up95CI = `97.5%`
  )

# Lookup vectors that turn the integer indices used by the Stan model back into
# labels: element i of each factor is the label for index i.

# Genotype labels; the level order equals the index order.
genotypes <- local({
  genotype_labels <- c(
    'N2',
    'brc-1(xoe4)',
    'smc-5(ok2421)',
    'smc-5(ok2421);brc-1(xoe4)',
    'smc-5(ok2421);brc-1(xoe4) MATED',
    'smc-5(ok2421);brc-1(xoe4);lig-4(ok716)',
    'smc-5(ok2421);brc-1(xoe4);polq-1(tm2572)',
    'smc-5(ok2421);brc-1(xoe4);polq-1(tm2572) MATED',
    'polq-1(tm2572)',
    'polq-1(tm2572) MATED'
  )
  factor(genotype_labels, levels = genotype_labels)
})

# Dose labels; the level order equals the index order.
treatments <- local({
  dose_labels <- c('0 Rads', '2500 Rads', '5000 Rads')
  factor(dose_labels, levels = dose_labels)
})

# Window labels: index 1 is 'Non-IH Window' and index 2 is 'IH Window', but the
# levels are listed with 'IH Window' first.
timepoints <- factor(
  c('Non-IH Window', 'IH Window'),
  levels = c('IH Window', 'Non-IH Window')
)


# Plot gamma values
# One error bar per genotype showing the 2.5%-97.5% posterior interval of gamma,
# faceted by dose (rows) and window (columns). Dose index 1 is dropped because
# gamma is p divided by the dose-1 p, so it carries no information there.
# NOTE(review): scale_y_continuous(limits = ...) removes any interval that
# extends outside 0-1.5 rather than clipping it; confirm that is intended.
stanrun_summ %>%
  filter(
    str_detect(Parameter, 'gamma'),
    Treatment != '1'
  ) %>%
  mutate(
    # Indices come out of separate() as text; convert before using them to
    # look up labels in the genotypes / timepoints / treatments factors.
    GenotypeID = as.numeric(GenotypeID),
    Genotype = genotypes[GenotypeID],
    Timepoint = timepoints[as.numeric(Timepoint)],
    Treatment = treatments[as.numeric(Treatment)]
  ) %>%
  ggplot() +
  geom_errorbar(
    aes(x = Genotype, ymin = low95CI, ymax = up95CI, color = Genotype),
    width = 0.2, size = 1
  ) +
  facet_grid(Treatment ~ Timepoint) +
  theme_bw() +
  theme(
    panel.grid.minor = element_blank(),
    legend.position = 'none',
    axis.text.x = element_text(angle = 60, vjust = 1, hjust = 1)
  ) +
  scale_y_continuous(
    limits = c(0, 1.5),
    breaks = c(seq(0, 1.5, .25)),
    expand = c(0, 0)
  ) +
  ylab('Gamma value')

################################################
### Assess model fit - posterior simulations ###
################################################

# Reshape one posterior-draw array (draw x genotype x dose x timepoint) into a
# long tibble with one row per draw and cell.
#   draws: a 4-d array as found in the list returned by rstan::extract()
# Returns a tibble with the draw `value`, numeric GenotypeID / TreatmentID /
# TimeptID, and the matching Genotype / Treatment / Timept labels. An index
# outside the known range is labelled 'ERROR'.
melt_posterior_array <- function(draws) {
  draws %>%
    melt() %>%
    as_tibble() %>%
    mutate(
      GenotypeID = as.numeric(Var2),
      TreatmentID = as.numeric(Var3),
      TimeptID = as.numeric(Var4),
      Genotype = genotypes[GenotypeID],
      Treatment = case_when(
        Var3 == 1 ~ '0 Rad',
        Var3 == 2 ~ '2500 Rad',
        Var3 == 3 ~ '5000 Rad',
        TRUE ~ 'ERROR'
      ),
      Timept = case_when(
        Var4 == 1 ~ 'Non-IH Window',
        Var4 == 2 ~ 'IH Window',
        TRUE ~ 'ERROR'
      )
    ) %>%
    select(-iterations, -Var2, -Var3, -Var4)
}

# Pull the draws out of the fit once and reuse them for both parameters
posterior_draws <- rstan::extract(stanrun)

#Extract phi value
phi <- melt_posterior_array(posterior_draws$phi)

#Extract lambda value
lambda <- melt_posterior_array(posterior_draws$lambda)


# Posterior predictive simulation.
# For every genotype x dose x window cell, simulate `n_sims` broods:
#   1. brood sizes are Poisson with mean equal to the observed mean brood size
#      of that cell,
#   2. (phi, lambda) pairs are resampled from the posterior draws of that cell,
#   3. a per-brood probability is drawn from Beta(phi * lambda, lambda * (1 - phi)),
#   4. live counts are Binomial(brood size, probability).
# The result, `postsim_dat`, has columns p_sim, bsize_sim, live_sim,
# GenotypeID, TreatmentID, TimeptID and Iter.
genos <- 10
treats <- 3
timepts <- 2
n_sims <- 1500


# Seed once, before any random draw (the original seeded on the first pass of
# the outer loop, which is the same point in the RNG stream).
set.seed(12345)

# Collect each cell's simulations in a preallocated list and bind them once at
# the end instead of re-copying a growing table on every pass.
sim_chunks <- vector('list', genos * treats * timepts)
chunk <- 0L

for (g in seq_len(genos)) {
  for (t in seq_len(treats)) {
    for (s in seq_len(timepts)) {

      observed_bsize <- stan_data %>%
        filter(Geno_num == g, Treat_num == t, Window_num == s) %>%
        pull(Bsize)
      bsize_sim <- rpois(n_sims, mean(observed_bsize))

      phi_draws <- phi %>%
        filter(GenotypeID == g, TreatmentID == t, TimeptID == s) %>%
        pull(value)
      lambda_draws <- lambda %>%
        filter(GenotypeID == g, TreatmentID == t, TimeptID == s) %>%
        pull(value)

      # phi and lambda are functions of the same (alpha, beta) draw, so they
      # must be resampled together: one shared index keeps each pair from the
      # same posterior draw. Sampling them independently would discard their
      # posterior correlation. This relies on both tables listing the draws of
      # a cell in the same order.
      stopifnot(
        length(phi_draws) > 0,
        length(phi_draws) == length(lambda_draws)
      )
      draw_idx <- sample.int(length(phi_draws), size = n_sims, replace = TRUE)
      phi_sample <- phi_draws[draw_idx]
      lambda_sample <- lambda_draws[draw_idx]

      # Convert back to the Beta shape parameters
      alpha_calc <- phi_sample * lambda_sample
      beta_calc <- lambda_sample * (1 - phi_sample)

      p_sim <- rbeta(n_sims, alpha_calc, beta_calc)
      live_sim <- rbinom(n_sims, size = bsize_sim, prob = p_sim)

      chunk <- chunk + 1L
      sim_chunks[[chunk]] <- tibble(
        p_sim = p_sim,
        bsize_sim = bsize_sim,
        live_sim = live_sim,
        GenotypeID = g,
        TreatmentID = t,
        TimeptID = s,
        Iter = rep_len(seq(1, 100), n_sims)
      )
    }
  }
}

postsim_dat <- bind_rows(sim_chunks)

#Plot posterior simulations - violins (White is empirical, colorful is simulated)

# Simulated broods: the genotype index (1-10) doubles as the fill key.
postsim_viability <- postsim_dat %>%
  mutate(
    Genotype = genotypes[GenotypeID],
    GenotypeID = as.numeric(GenotypeID),
    bviable = live_sim / bsize_sim,
    TimeptID = factor(TimeptID, levels = c(2, 1))
  )

# Observed broods: all tagged with GenotypeID 11 so they pick up the eleventh
# (final) fill colour, and given the same facet keys as the simulations.
empirical_viability <- stan_data %>%
  mutate(
    GenotypeID = 11,
    bviable = Live / Bsize,
    TreatmentID = Treat_num,
    TimeptID = factor(Window_num, levels = c(2, 1))
  )

# Overlay both sets of violins per genotype, faceted by dose (rows) and
# window (columns).
bind_rows(postsim_viability, empirical_viability) %>%
  ggplot() +
  geom_violin(
    aes(
      x = Genotype,
      y = bviable, #color=as.factor(GenotypeID),
      fill = as.factor(GenotypeID)
    ),
    alpha = 0.5,
    scale = 'area',
    position = 'identity',
    size = 1
  ) +
  #scale_color_manual(values=c('none','red','blue','purple','purple','green','pink','pink','purple','purple','black')) +
  scale_fill_manual(
    values = c(
      'grey', 'red', 'blue', 'purple', 'purple', 'green',
      'pink', 'pink', 'purple', 'purple', 'white'
    )
  ) +
  scale_alpha_manual(values = c(rep(1, 10), 1)) +
  facet_grid(TreatmentID ~ TimeptID) +
  theme_bw() +
  theme(
    panel.grid = element_blank(),
    legend.position = 'none',
    axis.text.x = element_text(angle = 60, vjust = 1, hjust = 1)
  ) +
  ylab('Brood Viability')
